{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第4章 卷积神经网络"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 二维卷积层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "# 2-D cross-correlation of a single-channel input X (H, W) with a kernel K (h, w).\n",
    "# Returns a float tensor of shape (H - h + 1, W - w + 1) on the same device as X.\n",
    "def corr2d(X, K):\n",
    "    # Kernel height and width\n",
    "    h, w = K.shape\n",
    "    # Output buffer; allocating it on X.device lets this function (and the Conv2D\n",
    "    # layer built on it) run on CUDA tensors without mixing CPU and GPU tensors\n",
    "    Y = torch.zeros((X.shape[0]-h+1, X.shape[1]-w+1), device=X.device)\n",
    "    for i in range(Y.shape[0]):\n",
    "        for j in range(Y.shape[1]):\n",
    "            # Elementwise product of the current window with the kernel, summed\n",
    "            Y[i, j] = (X[i:i+h, j:j+w] * K).sum()\n",
    "    return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[19., 25.],\n",
       "        [37., 43.]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Worked example: a 3x3 input and a 2x2 kernel give a 2x2 output\n",
    "X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])\n",
    "K = torch.tensor([[0, 1], [2, 3]])\n",
    "corr2d(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom 2-D convolution layer built on corr2d: a learnable kernel of shape\n",
    "# kernel_size and a learnable one-element bias, both initialized with randn.\n",
    "class Conv2D(nn.Module):\n",
    "    def __init__(self, kernel_size):\n",
    "        super().__init__()\n",
    "        self.weight = nn.Parameter(torch.randn(kernel_size))\n",
    "        self.bias = nn.Parameter(torch.randn(1))\n",
    "\n",
    "    def forward(self, x):\n",
    "        correlated = corr2d(x, self.weight)\n",
    "        return correlated + self.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.],\n",
       "        [1., 1., 0., 0., 0., 0., 1., 1.]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 6x8 image of ones whose columns 2 to 5 are set to 0\n",
    "X = torch.ones(6, 8)\n",
    "X[:, 2:6] = 0\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1x2 kernel: the output is nonzero only where horizontally adjacent elements differ\n",
    "K = torch.tensor([[1, -1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],\n",
       "        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# K is an integer tensor; cast it to float to match X. Columns 1 and 5 mark the edges\n",
    "Y = corr2d(X, K.float())\n",
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 5, loss 5.936\n",
      "Step 10, loss 1.648\n",
      "Step 15, loss 0.459\n",
      "Step 20, loss 0.128\n",
      "Step 25, loss 0.036\n",
      "Step 30, loss 0.010\n"
     ]
    }
   ],
   "source": [
    "# Learn the 1x2 kernel from (X, Y) by gradient descent on the squared error\n",
    "conv2d = Conv2D(kernel_size=(1, 2))\n",
    "step = 30\n",
    "lr = 0.01\n",
    "for i in range(step):\n",
    "    Y_hat = conv2d(X)\n",
    "    l = ((Y_hat - Y)**2).sum()\n",
    "    l.backward()\n",
    "    # Gradient descent step: update the parameters inside torch.no_grad()\n",
    "    # instead of mutating .data, which bypasses autograd's safety checks\n",
    "    with torch.no_grad():\n",
    "        conv2d.weight -= lr * conv2d.weight.grad\n",
    "        conv2d.bias -= lr * conv2d.bias.grad\n",
    "    # Reset the gradients, because backward() accumulates into .grad\n",
    "    conv2d.weight.grad.zero_()\n",
    "    conv2d.bias.grad.zero_()\n",
    "    if (i+1)%5 == 0:\n",
    "        print('Step %d, loss %.3f' % (i+1, l.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight:  tensor([[ 0.9746, -0.9748]])\n",
      "bias:  tensor([9.9195e-05])\n"
     ]
    }
   ],
   "source": [
    "# The learned parameters are expected to approach K = [[1, -1]] and a bias of 0\n",
    "print('weight: ', conv2d.weight.data)\n",
    "print('bias: ', conv2d.bias.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 填充和步幅"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 8])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Helper: apply a conv layer to one 2-D input by adding batch and channel\n",
    "# dimensions before the call and removing them afterwards\n",
    "def comp_conv2d(conv2d, X):\n",
    "    # Prepend batch size 1 and channel count 1 (multiple channels are covered in\n",
    "    # the section on multiple input and output channels): (8, 8) -> (1, 1, 8, 8)\n",
    "    batched = X.view(1, 1, *X.shape)\n",
    "    out = conv2d(batched)\n",
    "    # Strip the batch and channel dimensions we do not care about\n",
    "    return out.view(out.shape[2:])\n",
    "# padding=1 pads one row/column on EACH side, i.e. 2 rows and 2 columns in total\n",
    "conv2d = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)\n",
    "X = torch.rand(8, 8)\n",
    "comp_conv2d(conv2d, X).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 8])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Kernel of height 5 and width 3, padded by 2 on top/bottom and 1 on left/right\n",
    "conv2d = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(5, 3), padding=(2, 1))\n",
    "comp_conv2d(conv2d, X).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 4])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stride 2 halves the 8x8 input: (8 + 2*1 - 3) // 2 + 1 = 4\n",
    "conv2d = nn.Conv2d(1, 1, kernel_size=3, padding=1, stride=2)\n",
    "comp_conv2d(conv2d, X).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Height: (8 + 0 - 3) // 3 + 1 = 2; width: (8 + 2 - 5) // 4 + 1 = 2\n",
    "conv2d = nn.Conv2d(1, 1, kernel_size=(3, 5), padding=(0, 1), stride=(3, 4))\n",
    "comp_conv2d(conv2d, X).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 多输入通道和多输出通道"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "# Cross-correlation over multiple input channels: correlate each channel of X\n",
    "# (dim 0) with the matching channel of K and sum the per-channel results.\n",
    "# X: (c_i, H, W), K: (c_i, h, w) -> (H - h + 1, W - w + 1).\n",
    "# Raises ValueError if X and K have different numbers of channels.\n",
    "def corr2d_multi_in(X, K):\n",
    "    if X.shape[0] != K.shape[0]:\n",
    "        raise ValueError('X has %d channels but K has %d' % (X.shape[0], K.shape[0]))\n",
    "    # Use the corr2d defined at the top of this notebook, as Conv2D does\n",
    "    res = corr2d(X[0, :, :], K[0, :, :])\n",
    "    for i in range(1, X.shape[0]):\n",
    "        res += corr2d(X[i, :, :], K[i, :, :])\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 56.,  72.],\n",
       "        [104., 120.]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Two input channels; K holds one 2x2 kernel per input channel\n",
    "X = torch.tensor([\n",
    "    [[0, 1, 2], [3, 4, 5], [6, 7, 8]], \n",
    "    [[1, 2, 3], [4, 5, 6], [7, 8, 9]]\n",
    "], dtype=torch.float)\n",
    "K = torch.tensor([\n",
    "    [[0, 1], [2, 3]], \n",
    "    [[1, 2], [3, 4]]\n",
    "], dtype=torch.float)\n",
    "corr2d_multi_in(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-correlation with multiple output channels: run corr2d_multi_in of X\n",
    "# against each slice of K along dim 0 and stack the results along a new dim 0.\n",
    "def corr2d_multi_in_out(X, K):\n",
    "    per_output_channel = [corr2d_multi_in(X, kernel) for kernel in K]\n",
    "    return torch.stack(per_output_channel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 2, 2, 2])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Three output channels built from K, K+1 and K+2 -> shape (3, 2, 2, 2)\n",
    "K = torch.stack([K, K+1, K+2])\n",
    "K.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 56.,  72.],\n",
       "         [104., 120.]],\n",
       "\n",
       "        [[ 76., 100.],\n",
       "         [148., 172.]],\n",
       "\n",
       "        [[ 96., 128.],\n",
       "         [192., 224.]]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One 2x2 map per output channel; the first equals corr2d_multi_in(X, K[0])\n",
    "corr2d_multi_in_out(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1x1 convolution computed as one matrix product (a fully connected layer\n",
    "# applied at every pixel). X: (c_i, h, w), K: (c_o, c_i, 1, 1) -> (c_o, h, w).\n",
    "def corr2d_multi_in_out_1x1(X, K):\n",
    "    # Channels, height, width\n",
    "    c_i, h, w = X.shape\n",
    "    # Number of output channels\n",
    "    c_o = K.shape[0]\n",
    "    # Only a 1x1 kernel can be flattened to (c_o, c_i); fail with a clear message otherwise\n",
    "    if K.numel() != c_o * c_i:\n",
    "        raise ValueError('expected a 1x1 kernel with %d input channels, got shape %s' % (c_i, tuple(K.shape)))\n",
    "    # (channels, height*width): each channel is one feature, each pixel one column.\n",
    "    # reshape (unlike view) also accepts non-contiguous inputs such as transposed tensors\n",
    "    X = X.reshape(c_i, h*w)\n",
    "    K = K.reshape(c_o, c_i)\n",
    "    # Matrix product of the fully connected layer: (c_o, c_i) @ (c_i, h*w)\n",
    "    Y = torch.mm(K, X)\n",
    "    return Y.view(c_o, h, w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The matrix-product 1x1 convolution must match the general implementation\n",
    "X = torch.rand(3, 3, 3)\n",
    "K = torch.rand(2, 3, 1, 1)\n",
    "Y1 = corr2d_multi_in_out_1x1(X, K)\n",
    "Y2 = corr2d_multi_in_out(X, K)\n",
    "(Y1-Y2).norm().item() < 1e-6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.4 池化层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D max or average pooling (stride 1, no padding) of a 2-D tensor X.\n",
    "# pool_size: (p_h, p_w); mode: 'max' or 'avg'. Returns a float tensor of shape\n",
    "# (H - p_h + 1, W - p_w + 1) on X's device. Raises ValueError for any other mode.\n",
    "def pool2d(X, pool_size, mode='max'):\n",
    "    if mode == 'max':\n",
    "        reduce_fn = torch.max\n",
    "    elif mode == 'avg':\n",
    "        reduce_fn = torch.mean\n",
    "    else:\n",
    "        raise ValueError('unsupported pooling mode %r (expected max or avg)' % (mode,))\n",
    "    # Cast to float so that averaging works for integer inputs\n",
    "    X = X.float()\n",
    "    p_h, p_w = pool_size\n",
    "    # Buffer for the pooling results\n",
    "    Y = torch.zeros(X.shape[0]-p_h+1, X.shape[1]-p_w+1, device=X.device)\n",
    "    for i in range(Y.shape[0]):\n",
    "        for j in range(Y.shape[1]):\n",
    "            Y[i, j] = reduce_fn(X[i: i+p_h, j: j+p_w])\n",
    "    return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4., 5.],\n",
       "        [7., 8.]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3x3 example input; 2x2 max pooling keeps the largest value of each window\n",
    "X = torch.tensor([\n",
    "    [0, 1, 2], \n",
    "    [3, 4, 5], \n",
    "    [6, 7, 8]\n",
    "])\n",
    "pool2d(X, (2, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2., 3.],\n",
       "        [5., 6.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same input with 2x2 average pooling\n",
    "pool2d(X, (2, 2), mode='avg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 0.,  1.,  2.,  3.],\n",
       "          [ 4.,  5.,  6.,  7.],\n",
       "          [ 8.,  9., 10., 11.],\n",
       "          [12., 13., 14., 15.]]]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# nn pooling layers take (batch, channels, H, W): here 1 sample, 1 channel, 4x4\n",
    "X = torch.arange(16, dtype=torch.float).view((1, 1, 4, 4))\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[10.]]]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# nn.MaxPool2d's stride defaults to kernel_size, so a 3x3 window fits once in 4x4.\n",
    "# A separate name is used so the pool2d function defined above is not shadowed.\n",
    "max_pool = nn.MaxPool2d(3)\n",
    "max_pool(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 5.,  7.],\n",
       "          [13., 15.]]]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# padding=1 and stride=2 give a 2x2 output for the 4x4 input\n",
    "# (a separate name avoids shadowing the pool2d function defined earlier)\n",
    "max_pool = nn.MaxPool2d(3, padding=1, stride=2)\n",
    "max_pool(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 1.,  3.],\n",
       "          [ 9., 11.],\n",
       "          [13., 15.]]]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rectangular 2x4 window with separate padding and stride per dimension:\n",
    "# height (4 + 2 - 2) // 2 + 1 = 3, width (4 + 4 - 4) // 3 + 1 = 2\n",
    "# (a separate name avoids shadowing the pool2d function defined earlier)\n",
    "max_pool = nn.MaxPool2d((2, 4), padding=(1, 2), stride=(2, 3))\n",
    "max_pool(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 0.,  1.,  2.,  3.],\n",
       "          [ 4.,  5.,  6.,  7.],\n",
       "          [ 8.,  9., 10., 11.],\n",
       "          [12., 13., 14., 15.]],\n",
       "\n",
       "         [[ 1.,  2.,  3.,  4.],\n",
       "          [ 5.,  6.,  7.,  8.],\n",
       "          [ 9., 10., 11., 12.],\n",
       "          [13., 14., 15., 16.]]]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the single-channel input here so re-running this cell is idempotent,\n",
    "# then concatenate X and X+1 along dim 1 (the channel dimension) -> 2 channels\n",
    "X_single = torch.arange(16, dtype=torch.float).view((1, 1, 4, 4))\n",
    "X = torch.cat((X_single, X_single + 1), dim=1)\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 5.,  7.],\n",
       "          [13., 15.]],\n",
       "\n",
       "         [[ 6.,  8.],\n",
       "          [14., 16.]]]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pooling is applied to each input channel separately, so the output keeps 2 channels\n",
    "# (a separate name avoids shadowing the pool2d function defined earlier)\n",
    "max_pool = nn.MaxPool2d(3, padding=1, stride=2)\n",
    "max_pool(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.5 LeNet-卷积神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn, optim\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# LeNet: two (conv -> sigmoid -> max-pool) stages followed by three fully\n",
    "# connected layers producing 10 class scores. The layer order fixes the\n",
    "# Sequential indices shown by print(net), so keep it unchanged.\n",
    "class LeNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            # in_channels, out_channels, kernel_size\n",
    "            nn.Conv2d(1, 6, 5),\n",
    "            nn.Sigmoid(),\n",
    "            # kernel_size, stride\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.Sigmoid(),\n",
    "            nn.MaxPool2d(2, 2)\n",
    "        )\n",
    "        self.fc = nn.Sequential(\n",
    "            # 16 channels x 4 x 4 is what a 28x28 input leaves after the conv part\n",
    "            # (28 -> 24 -> 12 -> 8 -> 4)\n",
    "            nn.Linear(16*4*4, 120), \n",
    "            nn.Sigmoid(), \n",
    "            nn.Linear(120, 84), \n",
    "            nn.Sigmoid(), \n",
    "            nn.Linear(84, 10)\n",
    "        )\n",
    "    def forward(self, img):\n",
    "        feature = self.conv(img)\n",
    "        # Flatten each sample's feature maps into one vector: (batch, C*H*W)\n",
    "        output = self.fc(feature.view(img.shape[0], -1))\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LeNet(\n",
      "  (conv): Sequential(\n",
      "    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "    (1): Sigmoid()\n",
      "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "    (4): Sigmoid()\n",
      "    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (fc): Sequential(\n",
      "    (0): Linear(in_features=256, out_features=120, bias=True)\n",
      "    (1): Sigmoid()\n",
      "    (2): Linear(in_features=120, out_features=84, bias=True)\n",
      "    (3): Sigmoid()\n",
      "    (4): Linear(in_features=84, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Instantiate LeNet and print its layer structure\n",
    "net = LeNet()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fashion-MNIST without resizing; LeNet's first linear layer assumes 28x28 images\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 1.8646, train acc 0.318, test acc 0.587, time 4.7 sec\n",
      "epoch 2, loss 0.9301, train acc 0.646, test acc 0.685, time 4.3 sec\n",
      "epoch 3, loss 0.7464, train acc 0.720, test acc 0.727, time 4.2 sec\n",
      "epoch 4, loss 0.6674, train acc 0.741, test acc 0.744, time 4.2 sec\n",
      "epoch 5, loss 0.6137, train acc 0.757, test acc 0.760, time 4.3 sec\n"
     ]
    }
   ],
   "source": [
    "# Train with Adam (lr 0.001) for 5 epochs using the d2l helper train_ch5\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.6 AlexNet-深度卷积神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# AlexNet adapted to single-channel Fashion-MNIST images resized to 224x224.\n",
    "# The layer order fixes the Sequential indices shown by print(net).\n",
    "class AlexNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AlexNet, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            # in_channels, out_channels, kernel_size, stride\n",
    "            nn.Conv2d(1, 96, 11, 4), \n",
    "            nn.ReLU(), \n",
    "            # kernel_size, stride\n",
    "            nn.MaxPool2d(3, 2), \n",
    "            # Smaller conv window; padding 2 keeps height and width unchanged,\n",
    "            # and the number of output channels grows\n",
    "            nn.Conv2d(96, 256, 5, 1, 2),\n",
    "            nn.ReLU(), \n",
    "            nn.MaxPool2d(3, 2), \n",
    "            # Three consecutive conv layers with an even smaller window. All but the\n",
    "            # last one further increase the number of output channels.\n",
    "            # No pooling follows the first two, so they keep height and width\n",
    "            nn.Conv2d(256, 384, 3, 1, 1),\n",
    "            nn.ReLU(), \n",
    "            nn.Conv2d(384, 384, 3, 1, 1), \n",
    "            nn.ReLU(), \n",
    "            nn.Conv2d(384, 256, 3, 1, 1), \n",
    "            nn.ReLU(), \n",
    "            nn.MaxPool2d(3, 2)\n",
    "        )\n",
    "        # The fully connected layers are several times wider than in LeNet;\n",
    "        # dropout is used to mitigate overfitting\n",
    "        self.fc = nn.Sequential(\n",
    "            # 256 x 5 x 5 for a 224x224 input (224 -> 54 -> 26 -> 26 -> 12 -> 12 -> 5)\n",
    "            nn.Linear(256*5*5, 4096), \n",
    "            nn.ReLU(), \n",
    "            nn.Dropout(0.5), \n",
    "            nn.Linear(4096, 4096), \n",
    "            nn.ReLU(), \n",
    "            nn.Dropout(0.5), \n",
    "            # Output layer: 10 classes for Fashion-MNIST instead of the paper's 1000\n",
    "            nn.Linear(4096, 10)\n",
    "        )\n",
    "    def forward(self, img):\n",
    "        feature = self.conv(img)\n",
    "        # Flatten each sample's feature maps into one vector\n",
    "        output = self.fc(feature.view(img.shape[0], -1))\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AlexNet(\n",
      "  (conv): Sequential(\n",
      "    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))\n",
      "    (1): ReLU()\n",
      "    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "    (4): ReLU()\n",
      "    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (7): ReLU()\n",
      "    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (9): ReLU()\n",
      "    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU()\n",
      "    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (fc): Sequential(\n",
      "    (0): Linear(in_features=6400, out_features=4096, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Dropout(p=0.5)\n",
      "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (4): ReLU()\n",
      "    (5): Dropout(p=0.5)\n",
      "    (6): Linear(in_features=4096, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Instantiate AlexNet and print its layer structure\n",
    "net = AlexNet()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize images to 224x224, the input size AlexNet's fc layer sizes assume\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.6961, train acc 0.732, test acc 0.841, time 26.5 sec\n",
      "epoch 2, loss 0.3328, train acc 0.874, test acc 0.884, time 26.5 sec\n",
      "epoch 3, loss 0.2833, train acc 0.893, test acc 0.894, time 26.1 sec\n",
      "epoch 4, loss 0.2503, train acc 0.907, test acc 0.903, time 26.5 sec\n",
      "epoch 5, loss 0.2291, train acc 0.915, test acc 0.906, time 26.6 sec\n"
     ]
    }
   ],
   "source": [
    "# Train AlexNet with Adam (lr 0.001) for 5 epochs\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.7 VGG-使用重复元素的网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One VGG block: num_convs 3x3 convolutions (padding 1, each followed by ReLU),\n",
    "# then a 2x2 max pooling with stride 2 that halves height and width.\n",
    "def vgg_block(num_convs, in_channels, out_channels):\n",
    "    layers = []\n",
    "    channels = in_channels\n",
    "    for _ in range(num_convs):\n",
    "        # Only the first convolution changes the channel count\n",
    "        layers += [nn.Conv2d(channels, out_channels, kernel_size=3, padding=1), nn.ReLU()]\n",
    "        channels = out_channels\n",
    "    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry is (num_convs, in_channels, out_channels) for one vgg_block\n",
    "conv_arch = ((1, 1, 64), (1, 64, 128), (2, 128, 256), (2, 256, 512), (2, 512, 512))\n",
    "# After the 5 vgg_blocks, height and width have been halved 5 times: 224/2**5=224/32=7\n",
    "# c * w * h (channels * width * height) of the last block's output\n",
    "fc_features = 512*7*7\n",
    "fc_hidden_units = 4096"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a VGG network.\n",
    "# conv_arch: iterable of (num_convs, in_channels, out_channels), one per vgg_block;\n",
    "#   every block halves the height and width.\n",
    "# fc_features: number of flattened features entering the first linear layer\n",
    "#   (channels * height * width of the last block's output).\n",
    "# fc_hidden_units: width of the two hidden fully connected layers.\n",
    "# num_classes: size of the output layer (10 for Fashion-MNIST).\n",
    "def vgg(conv_arch, fc_features, fc_hidden_units=4096, num_classes=10):\n",
    "    net = nn.Sequential()\n",
    "    # Convolutional part\n",
    "    for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):\n",
    "        # Each vgg_block halves the height and width\n",
    "        net.add_module('vgg_block_' + str(i+1), vgg_block(num_convs, in_channels, out_channels))\n",
    "    # Fully connected part\n",
    "    net.add_module('fc', nn.Sequential(\n",
    "        d2l.FlattenLayer(), \n",
    "        nn.Linear(fc_features, fc_hidden_units), \n",
    "        nn.ReLU(), \n",
    "        nn.Dropout(0.5), \n",
    "        nn.Linear(fc_hidden_units, fc_hidden_units), \n",
    "        nn.ReLU(), \n",
    "        nn.Dropout(0.5), \n",
    "        nn.Linear(fc_hidden_units, num_classes)\n",
    "    ))\n",
    "    return net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vgg_block_1 output shape:  torch.Size([1, 64, 112, 112])\n",
      "vgg_block_2 output shape:  torch.Size([1, 128, 56, 56])\n",
      "vgg_block_3 output shape:  torch.Size([1, 256, 28, 28])\n",
      "vgg_block_4 output shape:  torch.Size([1, 512, 14, 14])\n",
      "vgg_block_5 output shape:  torch.Size([1, 512, 7, 7])\n",
      "fc output shape:  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "# Trace the output shape after each top-level block for one 1x224x224 input\n",
    "net = vgg(conv_arch, fc_features, fc_hidden_units)\n",
    "X = torch.rand(1, 1, 224, 224)\n",
    "# named_children yields the direct (first-level) submodules and their names\n",
    "# (named_modules would also return nested submodules)\n",
    "for name, blk in net.named_children():\n",
    "    X = blk(X)\n",
    "    print(name, 'output shape: ', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smaller VGG: every channel count and fully connected size divided by ratio\n",
    "ratio = 4\n",
    "small_conv_arch = ((1, 1, 64//ratio), (1, 64//ratio, 128//ratio), (2, 128//ratio, 256//ratio), (2, 256//ratio, 512//ratio), (2, 512//ratio, 512//ratio))\n",
    "net = vgg(small_conv_arch, fc_features//ratio, fc_hidden_units//ratio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.7833, train acc 0.702, test acc 0.839, time 44.3 sec\n",
      "epoch 2, loss 0.3814, train acc 0.860, test acc 0.873, time 44.5 sec\n",
      "epoch 3, loss 0.3125, train acc 0.885, test acc 0.895, time 44.5 sec\n",
      "epoch 4, loss 0.2730, train acc 0.899, test acc 0.904, time 44.6 sec\n",
      "epoch 5, loss 0.2428, train acc 0.911, test acc 0.912, time 44.5 sec\n"
     ]
    }
   ],
   "source": [
    "# Train the small VGG on 224x224 Fashion-MNIST with Adam (lr 0.001) for 5 epochs\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.8 NiN-网络中的网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# NiN block: one ordinary convolution followed by two 1x1 convolutions, each\n",
    "# followed by ReLU. The 1x1 convolutions act like per-pixel fully connected layers.\n",
    "def nin_block(in_channels, out_channels, kernel_size, stride, padding):\n",
    "    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU()]\n",
    "    for _ in range(2):\n",
    "        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=1))\n",
    "        layers.append(nn.ReLU())\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "class GlobalAvgPool2d(nn.Module):\n",
    "    # Global average pooling: the pooling window covers the whole height and\n",
    "    # width of the input, so every channel is reduced to a single value.\n",
    "    # nn.Module's default constructor suffices since there are no parameters.\n",
    "    def forward(self, x):\n",
    "        spatial_size = x.size()[2:]\n",
    "        return F.avg_pool2d(x, kernel_size=spatial_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NiN: stacked NiN blocks with max pooling in between; no fully connected layer,\n",
    "# the last block outputs one channel per class and global average pooling reduces it\n",
    "net = nn.Sequential(\n",
    "    nin_block(1, 96, kernel_size=11, stride=4, padding=0), \n",
    "    nn.MaxPool2d(kernel_size=3, stride=2), \n",
    "    nin_block(96, 256, kernel_size=5, stride=1, padding=2), \n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nin_block(256, 384, kernel_size=3, stride=1, padding=1), \n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nn.Dropout(0.5), \n",
    "    # 10 output channels, one per label class\n",
    "    nin_block(384, 10, kernel_size=3, stride=1, padding=1), \n",
    "    GlobalAvgPool2d(), \n",
    "    # Flatten the 4-D output into 2-D of shape (batch_size, 10)\n",
    "    d2l.FlattenLayer()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 output shape:  torch.Size([1, 96, 54, 54])\n",
      "1 output shape:  torch.Size([1, 96, 26, 26])\n",
      "2 output shape:  torch.Size([1, 256, 26, 26])\n",
      "3 output shape:  torch.Size([1, 256, 12, 12])\n",
      "4 output shape:  torch.Size([1, 384, 12, 12])\n",
      "5 output shape:  torch.Size([1, 384, 5, 5])\n",
      "6 output shape:  torch.Size([1, 384, 5, 5])\n",
      "7 output shape:  torch.Size([1, 10, 5, 5])\n",
      "8 output shape:  torch.Size([1, 10, 1, 1])\n",
      "9 output shape:  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "# Trace the output shape after each top-level layer for one 1x224x224 input\n",
    "X = torch.rand(1, 1, 224, 224)\n",
    "for name, blk in net.named_children():\n",
    "    X = blk(X)\n",
    "    print(name, 'output shape: ', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 1.2753, train acc 0.524, test acc 0.726, time 31.9 sec\n",
      "epoch 2, loss 0.6357, train acc 0.767, test acc 0.805, time 32.0 sec\n",
      "epoch 3, loss 0.5100, train acc 0.815, test acc 0.832, time 32.2 sec\n",
      "epoch 4, loss 0.4403, train acc 0.840, test acc 0.843, time 32.0 sec\n",
      "epoch 5, loss 0.3973, train acc 0.854, test acc 0.856, time 32.0 sec\n"
     ]
    }
   ],
   "source": [
    "# Train NiN on 224x224 Fashion-MNIST with Adam (lr 0.002) for 5 epochs\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "lr, num_epochs = 0.002, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.9 GoogLeNet-含并行连结的网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "class Inception(nn.Module):\n",
    "    # Inception block: four parallel paths whose outputs are concatenated along\n",
    "    # the channel dimension, giving c1 + c2[1] + c3[1] + c4 output channels.\n",
    "    # c1 and c4 are output channel counts; c2 and c3 are pairs\n",
    "    # (1x1 reduction channels, main convolution channels).\n",
    "    def __init__(self, in_c, c1, c2, c3, c4):\n",
    "        super(Inception, self).__init__()\n",
    "        # Path 1: a single 1x1 convolution\n",
    "        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)\n",
    "        # Path 2: 1x1 convolution followed by a 3x3 convolution\n",
    "        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)\n",
    "        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)\n",
    "        # Path 3: 1x1 convolution followed by a 5x5 convolution\n",
    "        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)\n",
    "        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)\n",
    "        # Path 4: 3x3 max pooling followed by a 1x1 convolution\n",
    "        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
    "        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)\n",
    "    def forward(self, x):\n",
    "        p1 = F.relu(self.p1_1(x))\n",
    "        # ReLU after the 1x1 reduction as well; without it the 1x1 and the\n",
    "        # following convolution collapse into a single linear map\n",
    "        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))\n",
    "        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))\n",
    "        p4 = F.relu(self.p4_2(self.p4_1(x)))\n",
    "        # Concatenate the path outputs along the channel dimension\n",
    "        return torch.cat((p1, p2, p3, p4), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage 1: 7x7 conv to 64 channels, then 3x3 max pooling; each has stride 2\n",
    "b1 = nn.Sequential(\n",
    "    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), \n",
    "    nn.ReLU(), \n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "b2 = nn.Sequential(\n",
    "    nn.Conv2d(64, 64, kernel_size=1), \n",
    "    nn.ReLU(), \n",
    "    nn.Conv2d(64, 192, kernel_size=3, padding=1), \n",
    "    nn.ReLU(), \n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "b3 = nn.Sequential(\n",
    "    Inception(192, 64, (96, 128), (16, 32), 32), \n",
    "    Inception(256, 128, (128, 192), (32, 96), 64),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "b4 = nn.Sequential(\n",
    "    Inception(480, 192, (96, 208), (16, 48), 64),\n",
    "    Inception(512, 160, (112, 224), (24, 64), 64),\n",
    "    Inception(512, 128, (128, 256), (24, 64), 64),\n",
    "    Inception(512, 112, (144, 288), (32, 64), 64),\n",
    "    Inception(528, 256, (160, 320), (32, 128), 128),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "b5 = nn.Sequential(\n",
    "    Inception(832, 256, (160, 320), (32, 128), 128),\n",
    "    Inception(832, 384, (192, 384), (48, 128), 128),\n",
    "    GlobalAvgPool2d()\n",
    ")\n",
    "net = nn.Sequential(\n",
    "    b1, b2, b3, b4, b5, \n",
    "    d2l.FlattenLayer(), \n",
    "    nn.Linear(1024, 10)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output shape:  torch.Size([1, 64, 24, 24])\n",
      "output shape:  torch.Size([1, 192, 12, 12])\n",
      "output shape:  torch.Size([1, 480, 6, 6])\n",
      "output shape:  torch.Size([1, 832, 3, 3])\n",
      "output shape:  torch.Size([1, 1024, 1, 1])\n",
      "output shape:  torch.Size([1, 1024])\n",
      "output shape:  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand(1, 1, 96, 96)\n",
    "for blk in net.children():\n",
    "    X = blk(X)\n",
    "    print('output shape: ', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 2.4547, train acc 0.099, test acc 0.100, time 23.0 sec\n",
      "epoch 2, loss 1.6381, train acc 0.344, test acc 0.496, time 23.1 sec\n",
      "epoch 3, loss 0.7035, train acc 0.733, test acc 0.726, time 23.1 sec\n",
      "epoch 4, loss 0.5071, train acc 0.815, test acc 0.824, time 23.1 sec\n",
      "epoch 5, loss 0.4271, train acc 0.842, test acc 0.826, time 23.1 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.10 批量归一化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
    "    # 判断当前模式是训练模式还是预测模式\n",
    "    if not is_training:\n",
    "        # 如果是在预测模式下，直接使用传入的移动平均所得的均值和方差\n",
    "        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
    "    else:\n",
    "        # 前一层需要为全连接层或卷积层\n",
    "        assert len(X.shape) in (2, 4)\n",
    "        # 全连接层\n",
    "        if len(X.shape) == 2:\n",
    "            # 沿纵向求均值，(1, 特征个数)\n",
    "            # 注意：逐特征求均值\n",
    "            mean = X.mean(dim=0)\n",
    "            # 广播\n",
    "            var = ((X - mean) ** 2).mean(dim=0)\n",
    "        else:\n",
    "            # 使用二维卷积层的情况，计算通道维上（axis=1）的均值和方差。这里我们需要保持\n",
    "            # X的形状以便后面可以做广播运算\n",
    "            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "        # 训练模式下用当前的均值和方差做标准化\n",
    "        X_hat = (X - mean) / torch.sqrt(var + eps)\n",
    "        # 一阶指数平滑算法\n",
    "        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
    "        moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
    "    # 拉伸和偏移\n",
    "    Y = gamma * X_hat + beta\n",
    "    return Y, moving_mean, moving_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchNorm(nn.Module):\n",
    "    def __init__(self, num_features, num_dims):\n",
    "        super(BatchNorm, self).__init__()\n",
    "        # 全连接层\n",
    "        if num_dims==2:\n",
    "            shape = (1, num_features)\n",
    "        # 卷积层\n",
    "        else:\n",
    "            shape = (1, num_features, 1, 1)\n",
    "        # 参与求梯度和迭代的拉伸和偏移参数，分别初始化成0和1\n",
    "        self.gamma = nn.Parameter(torch.ones(shape))\n",
    "        self.beta = nn.Parameter(torch.zeros(shape))\n",
    "        # 不参与求梯度和迭代的变量，全在内存上初始化成0\n",
    "        self.moving_mean = torch.zeros(shape)\n",
    "        self.moving_var = torch.zeros(shape)\n",
    "    def forward(self, X):\n",
    "        # 如果X不在显存上，将moving_mean和moving_var复制到X所在显存上\n",
    "        if self.moving_mean.device != X.device:\n",
    "            self.moving_mean = self.moving_mean.to(X.device)\n",
    "            self.moving_var = self.moving_var.to(X.device)\n",
    "        # 保存更新过的moving_mean和moving_var\n",
    "        # Module实例的traning属性默认为true, 调用.eval()后设成false\n",
    "        Y, self.moving_mean, self.moving_var = batch_norm(self.training, X, self.gamma, self.beta, self.moving_mean, self.moving_var, eps=1e-5, momentum=0.9)\n",
    "        return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "net = nn.Sequential(\n",
    "    # in_channels, out_channels, kernel_size\n",
    "    nn.Conv2d(1, 6, 5), \n",
    "    BatchNorm(6, num_dims=4), \n",
    "    nn.Sigmoid(), \n",
    "    # kernel_size, stride\n",
    "    nn.MaxPool2d(2, 2), \n",
    "    nn.Conv2d(6, 16, 5), \n",
    "    BatchNorm(16, num_dims=4), \n",
    "    nn.Sigmoid(), \n",
    "    nn.MaxPool2d(2, 2), \n",
    "    d2l.FlattenLayer(), \n",
    "    nn.Linear(16*4*4, 120), \n",
    "    BatchNorm(120, num_dims=2), \n",
    "    nn.Sigmoid(), \n",
    "    nn.Linear(120, 84), \n",
    "    BatchNorm(84, num_dims=2), \n",
    "    nn.Sigmoid(), \n",
    "    nn.Linear(84, 10)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 1.0072, train acc 0.788, test acc 0.827, time 6.7 sec\n",
      "epoch 2, loss 0.4596, train acc 0.862, test acc 0.836, time 6.5 sec\n",
      "epoch 3, loss 0.3726, train acc 0.876, test acc 0.852, time 6.5 sec\n",
      "epoch 4, loss 0.3339, train acc 0.886, test acc 0.866, time 6.6 sec\n",
      "epoch 5, loss 0.3132, train acc 0.890, test acc 0.855, time 6.3 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1.1173, 0.9368, 0.9685, 1.1043, 1.0498, 0.9036], device='cuda:0',\n",
       "        grad_fn=<ViewBackward>),\n",
       " tensor([ 0.2380,  0.0031, -0.4098,  0.3822, -0.6539, -0.1376], device='cuda:0',\n",
       "        grad_fn=<ViewBackward>))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[1].gamma.view((-1,)), net[1].beta.view((-1,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "    # in_channels, out_channels, kernel_size\n",
    "    nn.Conv2d(1, 6, 5), \n",
    "    nn.BatchNorm2d(6), \n",
    "    nn.Sigmoid(), \n",
    "    # kernel_size, stride\n",
    "    nn.MaxPool2d(2, 2), \n",
    "    nn.Conv2d(6, 16, 5), \n",
    "    nn.BatchNorm2d(16), \n",
    "    nn.Sigmoid(), \n",
    "    nn.MaxPool2d(2, 2), \n",
    "    d2l.FlattenLayer(), \n",
    "    nn.Linear(16*4*4, 120), \n",
    "    nn.BatchNorm1d(120), \n",
    "    nn.Sigmoid(), \n",
    "    nn.Linear(120, 84), \n",
    "    nn.BatchNorm1d(84), \n",
    "    nn.Sigmoid(), \n",
    "    nn.Linear(84, 10)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 1.3550, train acc 0.765, test acc 0.806, time 5.2 sec\n",
      "epoch 2, loss 0.5887, train acc 0.857, test acc 0.760, time 5.9 sec\n",
      "epoch 3, loss 0.4131, train acc 0.875, test acc 0.810, time 6.6 sec\n",
      "epoch 4, loss 0.3566, train acc 0.883, test acc 0.837, time 6.0 sec\n",
      "epoch 5, loss 0.3254, train acc 0.891, test acc 0.821, time 5.4 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.11 ResNet-残差网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "class Residual(nn.Module):\n",
    "    # 输入通道数、输出通道数、是否使用1x1卷积核、步长\n",
    "    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):\n",
    "        super(Residual, self).__init__()\n",
    "        # 3x3搭配1步长，特征图大小不变\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)\n",
    "        if use_1x1conv:\n",
    "            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
    "        else:\n",
    "            self.conv3 = None\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "    def forward(self, X):\n",
    "        Y = F.relu(self.bn1(self.conv1(X)))\n",
    "        Y = self.bn2(self.conv2(Y))\n",
    "        if self.conv3:\n",
    "            X = self.conv3(X)\n",
    "        return F.relu(Y+X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 6, 6])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = Residual(3, 3)\n",
    "X = torch.rand((4, 3, 6, 6))\n",
    "blk(X).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 6, 3, 3])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = Residual(3, 6, use_1x1conv=True, stride=2)\n",
    "X = torch.rand((4, 3, 6, 6))\n",
    "blk(X).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), \n",
    "    nn.BatchNorm2d(64), \n",
    "    nn.ReLU(), \n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def resnet_block(in_channels, out_channels, num_residuals, first_block=False):\n",
    "    if first_block:\n",
    "        # 第一个模块的通道数同输入通道数一致\n",
    "        assert in_channels == out_channels\n",
    "    blk = []\n",
    "    for i in range(num_residuals):\n",
    "        if i==0 and not first_block:\n",
    "            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))\n",
    "        else:\n",
    "            blk.append(Residual(out_channels, out_channels))\n",
    "    return nn.Sequential(*blk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.add_module('resnet_block1', resnet_block(64, 64, 2, first_block=True))\n",
    "net.add_module('resnet_block2', resnet_block(64, 128, 2))\n",
    "net.add_module('resnet_block3', resnet_block(128, 256, 2))\n",
    "net.add_module('resnet_block4', resnet_block(256, 512, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GlobalAvgPool2d的输出: (Batch, 512, 1, 1)\n",
    "net.add_module('global_avg_pool', d2l.GlobalAvgPool2d())\n",
    "net.add_module('fc', nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  output shape:\t torch.Size([1, 64, 112, 112])\n",
      "1  output shape:\t torch.Size([1, 64, 112, 112])\n",
      "2  output shape:\t torch.Size([1, 64, 112, 112])\n",
      "3  output shape:\t torch.Size([1, 64, 56, 56])\n",
      "resnet_block1  output shape:\t torch.Size([1, 64, 56, 56])\n",
      "resnet_block2  output shape:\t torch.Size([1, 128, 28, 28])\n",
      "resnet_block3  output shape:\t torch.Size([1, 256, 14, 14])\n",
      "resnet_block4  output shape:\t torch.Size([1, 512, 7, 7])\n",
      "global_avg_pool  output shape:\t torch.Size([1, 512, 1, 1])\n",
      "fc  output shape:\t torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand((1, 1, 224, 224))\n",
    "for name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(name, ' output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.4296, train acc 0.843, test acc 0.848, time 10.8 sec\n",
      "epoch 2, loss 0.2982, train acc 0.890, test acc 0.886, time 10.7 sec\n",
      "epoch 3, loss 0.2583, train acc 0.904, test acc 0.873, time 10.8 sec\n",
      "epoch 4, loss 0.2299, train acc 0.915, test acc 0.899, time 10.9 sec\n",
      "epoch 5, loss 0.2084, train acc 0.923, test acc 0.903, time 10.9 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.12 DenseNet-稠密连接网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "def conv_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(\n",
    "        nn.BatchNorm2d(in_channels), \n",
    "        nn.ReLU(), \n",
    "        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)\n",
    "    )\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseBlock(nn.Module):\n",
    "    def __init__(self, num_convs, in_channels, out_channels):\n",
    "        super(DenseBlock, self).__init__()\n",
    "        net = []\n",
    "        for i in range(num_convs):\n",
    "            in_c = in_channels + i*out_channels\n",
    "            net.append(conv_block(in_c, out_channels))\n",
    "        self.net = nn.ModuleList(net)\n",
    "        # 计算输出通道数\n",
    "        self.out_channels = in_channels + num_convs*out_channels\n",
    "    def forward(self, X):\n",
    "        for blk in self.net:\n",
    "            Y = blk(X)\n",
    "            # 在通道维上将输入和输出连结\n",
    "            X = torch.cat((X, Y), dim=1)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 23, 8, 8])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = DenseBlock(2, 3, 10)\n",
    "X = torch.rand(4, 3, 8, 8)\n",
    "Y = blk(X)\n",
    "Y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transition_block(in_channels, out_channels):\n",
    "    blk = nn.Sequential(\n",
    "        nn.BatchNorm2d(in_channels), \n",
    "        nn.ReLU(), \n",
    "        nn.Conv2d(in_channels, out_channels, kernel_size=1), \n",
    "        nn.AvgPool2d(kernel_size=2, stride=2)\n",
    "    )\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 10, 4, 4])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = transition_block(23, 10)\n",
    "blk(Y).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), \n",
    "    nn.BatchNorm2d(64), \n",
    "    nn.ReLU(), \n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# num_channels为当前的通道数\n",
    "num_channels, growth_rate = 64, 32\n",
    "num_convs_in_dense_blocks = [4, 4, 4, 4]\n",
    "for i, num_convs in enumerate(num_convs_in_dense_blocks):\n",
    "    DB = DenseBlock(num_convs, num_channels, growth_rate)\n",
    "    net.add_module('DenseBlock_%d' % i, DB)\n",
    "    # 上一个稠密块的输出通道数\n",
    "    num_channels = DB.out_channels\n",
    "    # 在稠密块之间加入通道数减半的过渡层\n",
    "    if i != len(num_convs_in_dense_blocks)-1:\n",
    "        net.add_module('transition_block_%d' % i, transition_block(num_channels, num_channels//2))\n",
    "        num_channels = num_channels // 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "net.add_module('BN', nn.BatchNorm2d(num_channels))\n",
    "net.add_module('relu', nn.ReLU())\n",
    "# GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)\n",
    "net.add_module('global_avg_pool', d2l.GlobalAvgPool2d())\n",
    "net.add_module('fc', nn.Sequential(\n",
    "    d2l.FlattenLayer(), \n",
    "    nn.Linear(num_channels, 10)\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "1  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "2  output shape:\t torch.Size([1, 64, 48, 48])\n",
      "3  output shape:\t torch.Size([1, 64, 24, 24])\n",
      "DenseBlock_0  output shape:\t torch.Size([1, 192, 24, 24])\n",
      "transition_block_0  output shape:\t torch.Size([1, 96, 12, 12])\n",
      "DenseBlock_1  output shape:\t torch.Size([1, 224, 12, 12])\n",
      "transition_block_1  output shape:\t torch.Size([1, 112, 6, 6])\n",
      "DenseBlock_2  output shape:\t torch.Size([1, 240, 6, 6])\n",
      "transition_block_2  output shape:\t torch.Size([1, 120, 3, 3])\n",
      "DenseBlock_3  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "BN  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "relu  output shape:\t torch.Size([1, 248, 3, 3])\n",
      "global_avg_pool  output shape:\t torch.Size([1, 248, 1, 1])\n",
      "fc  output shape:\t torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "X = torch.rand((1, 1, 96, 96))\n",
    "for name, layer in net.named_children():\n",
    "    X = layer(X)\n",
    "    print(name, ' output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.5218, train acc 0.828, test acc 0.833, time 12.7 sec\n",
      "epoch 2, loss 0.3087, train acc 0.887, test acc 0.869, time 12.4 sec\n",
      "epoch 3, loss 0.2656, train acc 0.902, test acc 0.890, time 12.5 sec\n",
      "epoch 4, loss 0.2360, train acc 0.914, test acc 0.894, time 12.2 sec\n",
      "epoch 5, loss 0.2186, train acc 0.919, test acc 0.871, time 12.3 sec\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:PyTorch] *",
   "language": "python",
   "name": "conda-env-PyTorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
